//
// Created by Pulsar on 2019/4/15.
//
#include <iostream>
#include <Eigen/Core>
#include <layers/ANNLayer.h>
#include <logic_functions/Loss.h>
#include <activate_functions/sigmod.h>
using namespace std;
// Runs a forward pass for one sample and returns the output-layer activations.
// The sample is transposed before the pass, so a 1 x input_nodes row becomes the
// input_nodes x 1 column that forword() multiplies by weights_input_to_hidden.
Eigen::MatrixXd AnnLayer::query(Eigen::MatrixXd input) {
    const Eigen::MatrixXd column_input = input.transpose();
    // forword() returns {final_outputs, hidden_outputs}; only the final layer is needed here.
    return this->forword(column_input).front();
}

// Performs a single gradient-descent step on one training sample.
//
// @param train_data_line the sample as a 1 x input_nodes row (the same layout as one row
//                        of the batch train() overload and of query()'s argument)
// @param label           the matching target as a 1 x output_nodes row
// @return always 0
int AnnLayer::train(Eigen::MatrixXd train_data_line,Eigen::MatrixXd label) {
    // forword() needs the sample as a column vector, while backword() consumes it as a row
    // (it forms hidden_errors * inputs to get a hidden x input weight update). Passing the
    // same orientation to both, as before, cannot satisfy both shapes when input_nodes > 1.
    std::vector<Eigen::MatrixXd> output = this->forword(train_data_line.transpose());
    const Eigen::MatrixXd final_output = output[0];
    const Eigen::MatrixXd hidden_output = output[1];
    this->backword(final_output, hidden_output, train_data_line, label);
    return 0;
}
// Trains the network for `epoch` passes over the data set, one sample (row) at a time,
// printing the mean per-sample loss after every pass.
//
// @param train_data one training sample per row, each row input_nodes wide
// @param label      one target per row; row i is the target for train_data row i
// @param epoch      number of full passes over train_data
// @return 0 on success, -1 if train_data and label have a different number of rows
int AnnLayer::train(Eigen::MatrixXd train_data,Eigen::MatrixXd label,int epoch) {
    const int lines = static_cast<int>(train_data.rows());
    std::cout<<"train_data "<<train_data.rows()<<","<<train_data.cols()<<std::endl;
    std::cout<<"label "<<label.rows()<<","<<label.cols()<<std::endl;
    // Every sample needs a target; indexing label.row() past its end is out of range.
    if (label.rows() != train_data.rows()) {
        std::cerr<<"train: train_data has "<<train_data.rows()<<" rows but label has "
                 <<label.rows()<<" rows"<<std::endl;
        return -1;
    }
    for (int i = 0; i < epoch; ++i) {
        std::cout<<"======================epoch "<<i<<"======================"<<std::endl;
        double loss = 0;
        for (int data_line = 0; data_line < lines; ++data_line) {
            const Eigen::MatrixXd sample_row = train_data.row(data_line);
            // forword() takes a column; backword() takes the sample and target as rows.
            std::vector<Eigen::MatrixXd> output = this->forword(sample_row.transpose());
            const Eigen::MatrixXd final_output = output[0];
            const Eigen::MatrixXd hidden_output = output[1];
            loss += this->backword(final_output, hidden_output, sample_row, label.row(data_line));
        }
        // Average over the samples seen in this epoch (previously divided by the epoch count,
        // which made the reported value depend on how many epochs were requested).
        const double mean_loss = lines > 0 ? loss / lines : 0.0;
        std::cout<<"======================loss: "<<mean_loss<<"======================"<<std::endl;
    }
    std::cout<<"======================train finished======================"<<std::endl;
    return 0;
}

// No manual cleanup is performed here; each member is released by its own destructor.
AnnLayer::~AnnLayer() = default;

// Backpropagation step for one sample. Updates weights_hidden_to_output,
// baises_hidden_to_output, weights_input_to_hidden and baises_input_to_hidden in place,
// scaled by learning_rate, and returns that sample's loss.
//
// @param final_outputs  output-layer activations from forword(), output_nodes x 1
// @param hidden_outputs hidden-layer activations from forword(), hidden_nodes x 1
// @param inputs         the sample as a 1 x input_nodes row
// @param targets        the expected output as a 1 x output_nodes row
// @return the sum of all entries of loss_function(targets^T, final_outputs)
//
// The o * (1 - o) factors below are the derivative of the sigmoid; they match the
// `sigmod` activation installed by the constructor and must change if activation_function does.
double AnnLayer::backword(Eigen::MatrixXd final_outputs,Eigen::MatrixXd hidden_outputs,Eigen::MatrixXd inputs,Eigen::MatrixXd targets) {
    const Eigen::MatrixXd target_column = targets.transpose();

    // Output delta, element-wise: (t - o) * o * (1 - o). The previous code formed the outer
    // product o * (1 - o)^T and multiplied it by (t - o), which mixes unrelated output units.
    const Eigen::MatrixXd output_errors =
        ((target_column - final_outputs).array()
         * final_outputs.array()
         * (1.0 - final_outputs.array())).matrix();

    // Hidden delta, computed with the weights *before* this step's update:
    // (W_ho^T * output_delta) * h * (1 - h), again element-wise.
    const Eigen::MatrixXd hidden_errors =
        ((this->weights_hidden_to_output.transpose() * output_errors).array()
         * hidden_outputs.array()
         * (1.0 - hidden_outputs.array())).matrix();

    // (output x 1) * (1 x hidden) -> output x hidden
    this->weights_hidden_to_output += this->learning_rate * output_errors * hidden_outputs.transpose();
    this->baises_hidden_to_output += this->learning_rate * output_errors;

    // (hidden x 1) * (1 x input) -> hidden x input; relies on `inputs` being a row.
    this->weights_input_to_hidden += this->learning_rate * hidden_errors * inputs;
    this->baises_input_to_hidden += this->learning_rate * hidden_errors;

    Eigen::MatrixXd tmp_loss = this->loss_function(targets.transpose(), final_outputs);
    return tmp_loss.sum();
}

// Forward pass through the two layers.
//
// @param inputs the sample(s) as columns: input_nodes rows, since it is left-multiplied
//               by weights_input_to_hidden (hidden_nodes x input_nodes)
// @return a two-element vector: [0] = output-layer activations (output_nodes rows),
//         [1] = hidden-layer activations (hidden_nodes rows)
std::vector<Eigen::MatrixXd> AnnLayer::forword(Eigen::MatrixXd inputs) {
    Eigen::MatrixXd hidden_inputs = this->weights_input_to_hidden * inputs + this->baises_input_to_hidden;
    Eigen::MatrixXd hidden_outputs = this->activation_function(hidden_inputs);

    // The output layer consumes the hidden activations. Multiplying the output x hidden
    // weights by the raw inputs (as before) skipped the hidden layer entirely and only
    // type-checked when hidden_nodes == input_nodes.
    Eigen::MatrixXd final_inputs = this->weights_hidden_to_output * hidden_outputs + this->baises_hidden_to_output;
    Eigen::MatrixXd final_outputs = this->activation_function(final_inputs);

    std::vector<Eigen::MatrixXd> outputs;
    outputs.reserve(2);
    outputs.push_back(final_outputs);
    outputs.push_back(hidden_outputs);
    return outputs;
}

// Builds a fully connected input -> hidden -> output network.
// All weights and biases start as Eigen Random() values, and the `sigmod` activation and
// `loss_variance` loss from the included headers are installed as the default functions.
AnnLayer::AnnLayer(double _learning_rate, int _input_nodes, int _hidden_nodes, int _output_nodes)
    : learning_rate(_learning_rate),
      input_nodes(_input_nodes),
      hidden_nodes(_hidden_nodes),
      output_nodes(_output_nodes) {
    // Random() draws from a shared pseudo-random stream, so the fill order below decides
    // which values each parameter receives; keep it stable for reproducible initial weights.
    weights_input_to_hidden = Eigen::MatrixXd::Random(_hidden_nodes, _input_nodes);    // hidden x input
    weights_hidden_to_output = Eigen::MatrixXd::Random(_output_nodes, _hidden_nodes);  // output x hidden
    baises_input_to_hidden = Eigen::MatrixXd::Random(_hidden_nodes, 1);               // hidden x 1
    baises_hidden_to_output = Eigen::MatrixXd::Random(_output_nodes, 1);              // output x 1

    activation_function = sigmod;
    loss_function = loss_variance;
}
